// Copyright Contributors to the OpenVDB Project
// SPDX-License-Identifier: MPL-2.0

/// @file points/AttributeArray.cc

#include "AttributeArray.h"
#include <map>

namespace openvdb {
OPENVDB_USE_VERSION_NAMESPACE
namespace OPENVDB_VERSION_NAME {
namespace points {


////////////////////////////////////////


namespace {

// Lookup table from an attribute type name pair to the factory that builds it.
using AttributeFactoryMap = std::map<NamePair, AttributeArray::FactoryMethod>;

// The factory table together with the spin mutex that serializes access to it.
struct LockedAttributeRegistry
{
    tbb::spin_mutex     mMutex;
    AttributeFactoryMap mMap;
};

// Accessor for the single, process-wide attribute registry.
// The registry is a function-local static: it is constructed on the first call
// (thread-safe since C++11) and every caller is handed the same instance.
LockedAttributeRegistry*
getAttributeRegistry()
{
    static LockedAttributeRegistry sInstance;
    return &sInstance;
}

} // unnamed namespace


////////////////////////////////////////

// AttributeArray::ScopedRegistryLock implementation

AttributeArray::ScopedRegistryLock::ScopedRegistryLock()
    : lock(getAttributeRegistry()->mMutex)
{
}


////////////////////////////////////////

// AttributeArray implementation


// Copy constructor. Delegates to the lock-taking overload, passing it a
// temporary scoped_lock constructed on rhs.mMutex. A temporary created in a
// mem-initializer lives until the end of that full-expression, which includes
// the delegated-to constructor, so rhs stays locked while its state is copied.
AttributeArray::AttributeArray(const AttributeArray& rhs)
    : AttributeArray(rhs, tbb::spin_mutex::scoped_lock(rhs.mMutex))
{
}


// Copy constructor body used while a lock is held. The unnamed scoped_lock
// parameter is never read. It only signals that the caller holds a lock; the
// single-argument copy constructor in this file passes one on rhs.mMutex.
//
// Copies the uniform, flag, paged-read and out-of-core state. Then:
//  - if the copied flags include PARTIALREAD, only rhs's compressed byte count
//    is copied;
//  - otherwise, when rhs has a page handle, this array receives the result of
//    rhs.mPageHandle->copy().
AttributeArray::AttributeArray(const AttributeArray& rhs, const tbb::spin_mutex::scoped_lock&)
    : mIsUniform(rhs.mIsUniform)
    , mFlags(rhs.mFlags)
    , mUsePagedRead(rhs.mUsePagedRead)
    , mOutOfCore(rhs.mOutOfCore.load())
    , mPageHandle()
{
    const bool partiallyRead = (mFlags & PARTIALREAD) != 0;
    if (partiallyRead) {
        mCompressedBytes = rhs.mCompressedBytes;
        return;
    }

    if (rhs.mPageHandle) {
        mPageHandle = rhs.mPageHandle->copy();
    }
}


// Copy assignment. Copies rhs's uniform, flag, paged-read and out-of-core
// state. Then it copies either rhs's compressed byte count (when rhs was
// partially read) or a copy of rhs's page handle.
//
// NOTE(review): the "zero the compressed bytes" comment implies that
// mCompressedBytes and mPageHandle overlap in storage (presumably a union in
// AttributeArray.h — confirm). Whichever member is currently active must
// therefore be cleared correctly before the other one is written.
AttributeArray&
AttributeArray::operator=(const AttributeArray& rhs)
{
    // Without this guard, self-assignment of a partially-read array would zero
    // mCompressedBytes and then copy that zero back from rhs, losing it.
    if (&rhs == this)               return *this;

    // Release what this array currently holds:
    //  - if it has been partially read, zero the compressed bytes so the page
    //    handle won't attempt to clean up invalid memory;
    //  - otherwise free any page handle it owns now, so the handle is not leaked
    //    when rhs's compressed byte count is written over it below.
    if (mFlags & PARTIALREAD)       mCompressedBytes = 0;
    else                            mPageHandle.reset();

    mIsUniform = rhs.mIsUniform;
    mFlags = rhs.mFlags;
    mUsePagedRead = rhs.mUsePagedRead;
    mOutOfCore.store(rhs.mOutOfCore);
    if (mFlags & PARTIALREAD)       mCompressedBytes = rhs.mCompressedBytes;
    else if (rhs.mPageHandle)       mPageHandle = rhs.mPageHandle->copy();
    else                            mPageHandle.reset();
    return *this;
}


// Builds a new attribute array of the registered type `type`. It does so by
// invoking that type's factory method with the given length, stride,
// constant-stride flag and metadata.
//
// The registry mutex is taken here unless `lock` is non-null, meaning the
// caller already holds it through a ScopedRegistryLock. The factory is called
// while the mutex is held.
//
// Throws LookupError if no factory has been registered for `type`.
AttributeArray::Ptr
AttributeArray::create(const NamePair& type, Index length, Index stride,
    bool constantStride, const Metadata* metadata, const ScopedRegistryLock* lock)
{
    auto* const registry = getAttributeRegistry();

    // Only acquire the registry mutex if the caller has not already locked it.
    tbb::spin_mutex::scoped_lock guard;
    if (lock == nullptr) guard.acquire(registry->mMutex);

    const auto& factories = registry->mMap;
    const auto entry = factories.find(type);
    if (entry == factories.end()) {
        OPENVDB_THROW(LookupError,
            "Cannot create attribute of unregistered type " << type.first << "_" << type.second);
    }

    const auto factory = entry->second;
    return factory(length, stride, constantStride, metadata);
}


bool
AttributeArray::isRegistered(const NamePair& type, const ScopedRegistryLock* lock)
{
    LockedAttributeRegistry* registry = getAttributeRegistry();
    tbb::spin_mutex::scoped_lock _lock;
    if (!lock)  _lock.acquire(registry->mMutex);
    return (registry->mMap.find(type) != registry->mMap.end());
}


void
AttributeArray::clearRegistry(const ScopedRegistryLock* lock)
{
    LockedAttributeRegistry* registry = getAttributeRegistry();
    tbb::spin_mutex::scoped_lock _lock;
    if (!lock)  _lock.acquire(registry->mMutex);
    registry->mMap.clear();
}


// Registers `factory` as the creator of attribute arrays of type `type`,
// replacing any factory previously registered for that type.
//
// As a sanity check, the factory is invoked once before registration (length 0,
// stride 0, non-constant stride, no metadata). The type() of the resulting array
// must equal `type`. This check runs before this function takes the registry
// mutex.
//
// The registry mutex is acquired here unless `lock` is non-null, which means
// the caller already holds it.
//
// Throws KeyError if `factory` is null, if it returns a null array, or if the
// array it creates reports a type other than `type`.
void
AttributeArray::registerType(const NamePair& type, FactoryMethod factory, const ScopedRegistryLock* lock)
{
    // Calling through a null factory pointer is undefined behavior; reject it.
    if (!factory) {
        OPENVDB_THROW(KeyError, "Cannot register attribute type " << type.first << "_"
            << type.second << " with a null factory method.");
    }

    { // check the type of the AttributeArray generated by the factory method
        auto array = (*factory)(/*length=*/0, /*stride=*/0, /*constantStride=*/false, /*metadata=*/nullptr);
        // Guard the dereference below against a factory that yields no array.
        if (!array) {
            OPENVDB_THROW(KeyError, "The factory method for attribute type " << type.first
                << "_" << type.second << " returned a null array.");
        }
        const NamePair& factoryType = array->type();
        if (factoryType != type) {
            OPENVDB_THROW(KeyError, "Attribute type " << type.first << "_" << type.second
                << " does not match the type created by the factory method "
                << factoryType.first << "_" << factoryType.second << ".");
        }
    }

    LockedAttributeRegistry* registry = getAttributeRegistry();
    tbb::spin_mutex::scoped_lock _lock;
    if (!lock)  _lock.acquire(registry->mMutex);

    registry->mMap[type] = factory;
}


void
AttributeArray::unregisterType(const NamePair& type, const ScopedRegistryLock* lock)
{
    LockedAttributeRegistry* registry = getAttributeRegistry();
    tbb::spin_mutex::scoped_lock _lock;
    if (!lock)  _lock.acquire(registry->mMutex);

    registry->mMap.erase(type);
}


// Sets (state == true) or clears (state == false) the TRANSIENT bit of mFlags.
// All other flag bits are left unchanged.
void
AttributeArray::setTransient(bool state)
{
    const auto bit = static_cast<uint8_t>(TRANSIENT);
    mFlags = static_cast<uint8_t>(state ? (mFlags | bit) : (mFlags & ~bit));
}


// Sets (state == true) or clears (state == false) the HIDDEN bit of mFlags.
// All other flag bits are left unchanged.
void
AttributeArray::setHidden(bool state)
{
    const auto bit = static_cast<uint8_t>(HIDDEN);
    mFlags = static_cast<uint8_t>(state ? (mFlags | bit) : (mFlags & ~bit));
}


// Sets (state == true) or clears (state == false) the STREAMING bit of mFlags.
// All other flag bits are left unchanged.
void
AttributeArray::setStreaming(bool state)
{
    const auto bit = static_cast<uint8_t>(STREAMING);
    mFlags = static_cast<uint8_t>(state ? (mFlags | bit) : (mFlags & ~bit));
}


// Sets (state == true) or clears (state == false) the CONSTANTSTRIDE bit of
// mFlags. All other flag bits are left unchanged.
void
AttributeArray::setConstantStride(bool state)
{
    const auto bit = static_cast<uint8_t>(CONSTANTSTRIDE);
    mFlags = static_cast<uint8_t>(state ? (mFlags | bit) : (mFlags & ~bit));
}


// Equality comparison. It proceeds in three steps:
//  1. Call loadData() on both arrays (presumably so any out-of-core data is
//     resident before comparing — confirm).
//  2. Require matching paged-read mode and flag bits.
//  3. Only when those match, defer to isEqual() for the comparison of values.
bool
AttributeArray::operator==(const AttributeArray& other) const
{
    this->loadData();
    other.loadData();

    const bool sameHeader = (mUsePagedRead == other.mUsePagedRead)
        && (mFlags == other.mFlags);
    return sameHeader && this->isEqual(other);
}

} // namespace points
} // namespace OPENVDB_VERSION_NAME
} // namespace openvdb
